{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import classifier, training\n",
    "tf.enable_eager_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First let's load the MNIST dataset of hand-written digits from `tensorflow`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 28, 28) (60000,)\n",
      "(10000, 28, 28) (10000,)\n"
     ]
    }
   ],
   "source": [
    "mnist = tf.keras.datasets.mnist\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data(path='mnist.npz')\n",
    "\n",
    "print(x_train.shape, y_train.shape)\n",
    "print(x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next let's encode the data using the feature map $\\Phi (p) = (p, 1-p)^T$ and transform the labels to one-hot format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n",
      "(60000, 784, 2) (60000, 10)\n",
      "(10000, 784, 2) (10000, 10)\n"
     ]
    }
   ],
   "source": [
    "def data_encoder(data):\n",
    "  return np.array([1 - data, data]).transpose([1, 2, 0])\n",
    "\n",
    "def to_one_hot(labels, n_labels=10):\n",
    "  one_hot = np.zeros((len(labels), n_labels))\n",
    "  one_hot[np.arange(len(labels)), labels] = 1\n",
    "  return one_hot\n",
    "\n",
    "n_labels = len(np.unique(y_train))\n",
    "\n",
    "# Flatten and normalize\n",
    "x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:]))) / 255.0\n",
    "x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:]))) / 255.0\n",
    "# Encode\n",
    "x_train = data_encoder(x_train)\n",
    "x_test = data_encoder(x_test)\n",
    "y_train = to_one_hot(y_train)\n",
    "y_test = to_one_hot(y_test)\n",
    "\n",
    "print(n_labels)\n",
    "print(x_train.shape, y_train.shape)\n",
    "print(x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define MPS classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that our MPS has one more site than the data because of the label tensor. We also have to set the bond dimension which is a hyperparameter and remains constant during training. In a more sophisticated implementation the bond dimension can be adaptively changed according to the complexity of training data by performing some SVD steps. This is currently not implemented but can be added in a future version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "mps = classifier.MatrixProductState(n_sites=x_train.shape[1] + 1,\n",
    "                                    n_labels=n_labels,\n",
    "                                    d_phys=x_train.shape[2],\n",
    "                                    d_bond=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can train the `mps` object we created using the `training.fit` data. Here we perform a quick training in a small portion of the data without validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch: 0\n",
      "Time: 8.074748754501343\n",
      "Loss: 2.275792121887207\n",
      "Accuracy: 0.125\n",
      "\n",
      "Epoch: 1\n",
      "Time: 16.331595182418823\n",
      "Loss: 1.8269275426864624\n",
      "Accuracy: 0.332\n",
      "\n",
      "Epoch: 2\n",
      "Time: 24.106159210205078\n",
      "Loss: 1.1868376731872559\n",
      "Accuracy: 0.652\n",
      "\n",
      "Epoch: 3\n",
      "Time: 31.7194664478302\n",
      "Loss: 0.711657702922821\n",
      "Accuracy: 0.755\n",
      "\n",
      "Epoch: 4\n",
      "Time: 39.23139977455139\n",
      "Loss: 0.5140281319618225\n",
      "Accuracy: 0.831\n",
      "\n",
      "Epoch: 5\n",
      "Time: 46.81901288032532\n",
      "Loss: 0.3754191994667053\n",
      "Accuracy: 0.88\n",
      "\n",
      "Epoch: 6\n",
      "Time: 55.55578422546387\n",
      "Loss: 0.291353315114975\n",
      "Accuracy: 0.91\n",
      "\n",
      "Epoch: 7\n",
      "Time: 64.64535021781921\n",
      "Loss: 0.27905139327049255\n",
      "Accuracy: 0.913\n",
      "\n",
      "Epoch: 8\n",
      "Time: 73.24950790405273\n",
      "Loss: 0.2664482891559601\n",
      "Accuracy: 0.913\n",
      "\n",
      "Epoch: 9\n",
      "Time: 80.5286967754364\n",
      "Loss: 0.25159114599227905\n",
      "Accuracy: 0.92\n"
     ]
    }
   ],
   "source": [
    "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)\n",
    "\n",
    "mps, history = training.fit(mps, optimizer, x_train[:1000], y_train[:1000],\n",
    "                            n_epochs=10, batch_size=50, n_message=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.12 64-bit ('tf36': conda)",
   "metadata": {
    "interpreter": {
     "hash": "b871ec03f0cd0b2b0e1e1c9c84109c250f6229452d3d8d61c70e6b0e86220b02"
    }
   },
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}